package com.lwl.concurrency.forkjoin;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;

/**
 * Demonstrates a fork/join computation that returns no result: a {@link RecursiveAction}
 * that raises the price of every product in a list by recursively splitting the index range.
 *
 * <p>Created by liwenlong on 2018/1/12 22:26
 */
public class ForkJoinActionTest {

    /** Simple mutable product with a name and a price. */
    public static class Product {
        private String name;
        private double price;

        /** @return the product name */
        public String getName() {
            return name;
        }

        /** @param name the product name */
        public void setName(String name) {
            this.name = name;
        }

        /** @return the current price */
        public double getPrice() {
            return price;
        }

        /** @param price the new price */
        public void setPrice(double price) {
            this.price = price;
        }
    }

    /**
     * Generates product lists for the demo. Products are named {@code Product-<index>} and all
     * start with the same fixed price of 10 (nothing here is random).
     */
    public static class ProductListGenerator {
        /**
         * Builds a list of {@code size} products.
         *
         * @param size number of products to create; values {@code <= 0} yield an empty list
         * @return a new mutable list of products
         */
        public List<Product> generate(int size) {
            // Presize the backing array; clamp so a negative size still returns an empty list.
            List<Product> list = new ArrayList<>(Math.max(size, 0));
            for (int i = 0; i < size; i++) {
                Product product = new Product();
                product.setName("Product-" + i);
                product.setPrice(10);
                list.add(product);
            }
            return list;
        }
    }

    /**
     * A task that returns no result. It multiplies the price of every product in the half-open
     * index range {@code [first, last)} by {@code 1 + increment}, splitting the range in two
     * whenever it holds {@link #THRESHOLD} or more elements.
     */
    public static class Task extends RecursiveAction {

        private static final long serialVersionUID = -5244147827409479021L;

        /** Ranges with fewer elements than this are processed directly instead of being split. */
        private static final int THRESHOLD = 10;

        /** Length of the artificial pause that simulates slow work in each leaf task. */
        private static final long SIMULATED_WORK_SECONDS = 3;

        private final List<Product> products;
        private final int first;
        private final int last;
        private final double increment;

        /**
         * @param products  the product collection
         * @param first     start index (inclusive) of the products to update
         * @param last      end index (exclusive) of the products to update
         * @param increment relative price increase, e.g. {@code 0.2} for +20%
         */
        public Task(List<Product> products, int first, int last, double increment) {
            this.products = products;
            this.first = first;
            this.last = last;
            this.increment = increment;
        }

        /**
         * Updates the range directly when it is small enough; otherwise splits it into two
         * sub-tasks and waits for both to finish.
         */
        @Override
        protected void compute() {
            if (last - first < THRESHOLD) {
                updatePrices();
            } else {
                // Overflow-safe midpoint; equals (first + last) / 2 for valid index ranges.
                int middle = first + (last - first) / 2;
                System.out.println(Thread.currentThread().getName()+" task:pending tasks:" + getQueuedTaskCount());
                // Both halves use half-open ranges, so index middle + 1 belongs to task2 only.
                Task task1 = new Task(products, first, middle + 1, increment);
                Task task2 = new Task(products, middle + 1, last, increment);
                invokeAll(task1, task2);
                System.out.println("阻塞结束");
            }
        }

        /**
         * Updates the products located between {@code first} (inclusive) and {@code last}
         * (exclusive) in the product list, after a pause that simulates slow work.
         *
         * <p>If the pause is interrupted, the price update is still performed (best effort) and
         * the thread's interrupt status is restored so callers can observe the interruption.
         */
        private void updatePrices() {
            try {
                TimeUnit.SECONDS.sleep(SIMULATED_WORK_SECONDS);
            } catch (InterruptedException e) {
                e.printStackTrace();
                // sleep() cleared the flag when it threw; set it again instead of swallowing it.
                Thread.currentThread().interrupt();
            }
            for (int i = first; i < last; i++) {
                Product product = products.get(i);
                product.setPrice(product.getPrice() * (1 + increment));
            }
        }
    }

    /**
     * Runs the demo: generates 10000 products, raises every price by 20% on a fork/join pool,
     * then prints any product whose price is not the expected 12.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ProductListGenerator generator = new ProductListGenerator();
        List<Product> products = generator.generate(10000);
        Task task = new Task(products, 0, products.size(), 0.2);

        /*
         Create a pool whose parallelism equals the number of available processors.
         */
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // Asynchronous submission
            pool.execute(task);
            // Synchronous alternative:
            //pool.invoke(task);
            /* Optional progress monitoring:
            do {
                System.out.printf("Main: Thread Count: %d\n", pool.getActiveThreadCount());
                System.out.printf("Main: Thread Steal: %d\n", pool.getStealCount());
                System.out.printf("Main: Parallelism: %d\n", pool.getParallelism());
                try {
                    TimeUnit.MILLISECONDS.sleep(5);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            } while (!task.isDone());*/

            // Wait for the task to complete
            task.join();
        } finally {
            // Shut the pool down even when join() rethrows a failure from the task.
            pool.shutdown();
        }

        if (task.isCompletedNormally()) {
            System.out.printf("Main: The process has completed normally.\n");
        }
        for (int i = 0; i < products.size(); i++) {
            Product product = products.get(i);
            // Compare with a tolerance: the price is the result of floating-point arithmetic.
            if (Math.abs(product.getPrice() - 12) > 1e-9) {
                System.out.printf("Product................... %s: %f\n", product.getName(), product.getPrice());
            }
        }
        System.out.println("Main: End of the program.\n");
    }

}
